package com.nbsaas.nbmall.web.interceptor;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.RateLimiter;
import com.haoxuer.discover.rest.base.ResponseObject;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

/**
 * Spring MVC interceptor that throttles incoming requests at two levels.
 *
 * <ol>
 *   <li>A per-client-IP limiter, {@value #IP_REQUEST_COUNT} permits/second. IP limiters are
 *       cached and evicted after 20 minutes or when more than 5000 IPs are tracked.</li>
 *   <li>A site-wide limiter, {@value #REQUEST_COUNT} permits/second, shared by all instances of
 *       this class.</li>
 * </ol>
 *
 * <p>A rejected request receives a JSON {@link ResponseObject} body and is not passed on to the
 * handler. The code is 1001 for the per-IP limit and 1000 for the global limit.
 */
public class LimitInterceptor extends HandlerInterceptorAdapter {

    /** Maximum number of requests per second the whole site accepts. */
    public static final int REQUEST_COUNT = 2000;

    /** Maximum number of requests per second accepted from a single client IP. */
    public static final int IP_REQUEST_COUNT = 500;

    /** Site-wide limiter; static, so it is shared by every interceptor instance. */
    private static final RateLimiter GLOBAL_LIMITER = RateLimiter.create(REQUEST_COUNT);

    /** Jackson's ObjectMapper is thread-safe once configured, so one instance is reused. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /**
     * Lazily created per-IP limiters. Bounded in size and expired after write, so that
     * long-gone clients do not accumulate.
     */
    final LoadingCache<String, RateLimiter> ipRequestCaches = CacheBuilder.newBuilder()
            .maximumSize(5000) // maximum number of IPs tracked at once
            .expireAfterWrite(20, TimeUnit.MINUTES)
            .build(new CacheLoader<String, RateLimiter>() {
                @Override
                public RateLimiter load(String ip) {
                    // First request from a new IP: give it IP_REQUEST_COUNT permits/second.
                    return RateLimiter.create(IP_REQUEST_COUNT);
                }
            });

    /**
     * Rejects the request if either the caller's per-IP limiter or the global limiter has no
     * permit available right now; never blocks waiting for a permit.
     *
     * @return {@code true} to continue processing, {@code false} if a rejection body was written
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        String ip = getIpAddress(request);
        RateLimiter limiter = getIPLimiter(ip);
        if (!limiter.tryAcquire()) {
            ResponseObject result = new ResponseObject();
            result.setCode(1001);
            result.setMsg("ip访问过快");
            handleNone(response, result);
            return false;
        }
        if (!GLOBAL_LIMITER.tryAcquire()) {
            ResponseObject result = new ResponseObject();
            result.setCode(1000);
            result.setMsg("服务器压力过大");
            handleNone(response, result);
            return false;
        }
        return true;
    }

    /**
     * Resolves the client IP address of a request.
     *
     * <p>The lookup order is:
     * <ol>
     *   <li>{@code X-Real-IP};</li>
     *   <li>the first usable entry of the comma-separated {@code X-Forwarded-For};</li>
     *   <li>{@link HttpServletRequest#getRemoteAddr()}.</li>
     * </ol>
     * Blank values and the placeholder {@code "unknown"} (case-insensitive) are skipped.
     *
     * <p>NOTE(review): both headers are client-controlled unless a trusted reverse proxy
     * overwrites them. Without such a proxy, a client can forge them to evade the per-IP limit.
     *
     * @param request the incoming request
     * @return the resolved, trimmed client IP
     */
    public static String getIpAddress(HttpServletRequest request) throws Exception {
        String realIp = request.getHeader("X-Real-IP");
        if (isUsableIp(realIp)) {
            return realIp.trim();
        }
        String forwardedFor = request.getHeader("X-Forwarded-For");
        if (!StringUtils.isBlank(forwardedFor)) {
            // After several reverse proxies the header holds "client, proxy1, proxy2, ...";
            // the left-most usable entry is the originating client.
            for (String candidate : forwardedFor.split(",")) {
                if (isUsableIp(candidate)) {
                    return candidate.trim();
                }
            }
        }
        return request.getRemoteAddr();
    }

    /** Returns true if {@code ip} is non-blank and not the "unknown" placeholder. */
    private static boolean isUsableIp(String ip) {
        return !StringUtils.isBlank(ip) && !"unknown".equalsIgnoreCase(ip.trim());
    }

    /**
     * Returns the limiter for the given IP, creating it on first use.
     *
     * @throws ExecutionException if the cache loader fails (the loader here does not throw)
     */
    public RateLimiter getIPLimiter(String ipAddress) throws ExecutionException {
        return ipRequestCaches.get(ipAddress);
    }

    /**
     * Writes {@code object} to the response as a UTF-8 JSON body.
     *
     * <p>The HTTP status is left untouched (normally 200). Clients are expected to read the
     * {@code code} field instead.
     */
    public void handleNone(HttpServletResponse response, ResponseObject object) throws IOException {
        response.setCharacterEncoding("UTF-8");
        response.setContentType("application/json; charset=utf-8");
        response.getWriter().println(OBJECT_MAPPER.writeValueAsString(object));
    }
}
